# Copyright (c) 2015 Uber Technologies, Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""
Provides functionality that is used exclusively at runtime.
"""
from __future__ import absolute_import, unicode_literals, print_function

from libc.stdint cimport int32_t

from thriftrw.protocol.core cimport Protocol
from thriftrw.wire cimport mtype
from thriftrw.wire cimport ttype
from thriftrw.wire.value cimport Value, StructValue
from thriftrw.spec.service import (
    FunctionArgsSpec,
    FunctionResultSpec,
)
from thriftrw.errors import (
    ThriftProtocolError,
    UnknownExceptionError,
)


cdef class Serializer(object):
    """Converts thriftrw-generated objects into bytes using a ``Protocol``."""

    def __cinit__(self, Protocol protocol):
        self.protocol = protocol

    def __call__(self, obj):
        """Serialize a Thrift object; equivalent to calling ``dumps(obj)``.

        :param obj:
            An instance of a class generated by thriftrw representing a
            struct, union, or exception.
        :returns:
            Binary representation of the object.
        """
        return self.dumps(obj)

    cpdef bytes dumps(self, obj):
        """Serialize ``obj`` into its binary representation.

        The object is first converted to its wire form through the
        ``type_spec`` attached to its class, and the result is then handed to
        the protocol for encoding.

        :param obj:
            Object whose class exposes a ``type_spec`` attribute.
        :returns:
            Bytes produced by the protocol for the wire value.
        """
        cdef Value wire_value

        spec = obj.__class__.type_spec
        wire_value = spec.to_wire(obj)
        return self.protocol.serialize_value(wire_value)

    cpdef bytes message(self, obj, int32_t seqid=0):
        """Wrap a request or response in a Thrift Message and serialize it.

        :param obj:
            Request or response to serialize. This is constructed using the
            ``request`` or ``response`` attribute on a
            :py:class:`thriftrw.spec.ServiceFunction`.
        :param int seqid:
            Sequence ID to put on the message. Defaults to 0.
        :returns:
            Bytes produced by the protocol for the message.
        :raises TypeError:
            If the spec of ``obj`` is neither a function arguments spec nor a
            function result spec.
        """
        cdef StructValue payload
        cdef Message envelope

        spec = obj.__class__.type_spec

        if isinstance(spec, FunctionArgsSpec):
            # Requests go out as ONEWAY or CALL depending on the function.
            func = spec.function
            kind = mtype.ONEWAY if func.oneway else mtype.CALL
        elif isinstance(spec, FunctionResultSpec):
            # Responses are always sent as REPLY.
            func = spec.function
            kind = mtype.REPLY
        else:
            raise TypeError(
                'Only function request or response types may be wrapped '
                'in messages.'
            )

        payload = spec.to_wire(obj)
        envelope = Message(func.name, seqid, kind, payload)
        return self.protocol.serialize_message(envelope)


cdef class Deserializer(object):
    """Parses binary blobs into thriftrw objects using a ``Protocol``."""

    def __cinit__(self, Protocol protocol):
        self.protocol = protocol

    def __call__(self, obj_cls, bytes s):
        """Deserializes an object from the given blob.

        Equivalent to calling ``loads(obj_cls, s)``.

        :param obj_cls:
            A class generated by thriftrw representing a struct, union, or
            exception.
        :param bytes s:
            Binary blob representing the object.
        :returns:
            Deserialized object.
        :raises thriftrw.errors.ThriftProtocolError:
            If the object failed to deserialize.
        """
        return self.loads(obj_cls, s)

    cpdef object loads(self, obj_cls, bytes s):
        """Deserializes an object of type ``obj_cls`` from the given blob.

        The blob is decoded by the protocol as a ``STRUCT`` value, which is
        then converted into an object through ``obj_cls.type_spec``.

        :param obj_cls:
            Class exposing a ``type_spec`` attribute used to build the object
            from its wire representation.
        :param bytes s:
            Binary blob representing the object.
        :returns:
            Whatever ``obj_cls.type_spec.from_wire`` produces for the decoded
            value.
        """
        cdef Value value = self.protocol.deserialize_value(ttype.STRUCT, s)
        return obj_cls.type_spec.from_wire(value)

    cpdef Message message(self, service, bytes s):
        """Deserializes a message from the given blob.

        :param service:
            Reference to a service class. The request or response for one of
            the methods of this service will be read from the blob based on
            the message type.
        :param s:
            Binary blob representing the message and its payload.
        :returns:
            A Message containing the parsed request or response object in the
            ``body``.
        :raises UnknownExceptionError:
            If a ``EXCEPTION`` message was parsed.
        :raises ThriftProtocolError:
            If the method name is not recognized, if a ``REPLY`` is received
            for a oneway method, or if any other parsing error occurs.
        :raises ValueError:
            If the message has a type other than ``CALL``, ``ONEWAY``,
            ``REPLY`` or ``EXCEPTION``.
        """
        service_spec = service.service_spec

        cdef Message message = self.protocol.deserialize_message(s)

        if message.message_type == mtype.EXCEPTION:
            # For EXCEPTION messages, just raise UnknownExceptionError with
            # the struct representation in the message. This happens before
            # the method lookup, so the method name is not validated here.
            raise UnknownExceptionError(
                'Received an exception message.', message.body
            )

        function_spec = service_spec.lookup(message.name)
        if function_spec is None:
            raise ThriftProtocolError(
                'Unknown method "%s" referenced by message %r'
                % (message.name, message)
            )

        if (
            message.message_type == mtype.CALL or
            message.message_type == mtype.ONEWAY
        ):
            # Requests: replace the wire-level body with the parsed arguments.
            message.body = function_spec.args_spec.from_wire(message.body)
        elif message.message_type == mtype.REPLY:
            if function_spec.oneway:
                raise ThriftProtocolError(
                    'Function "%s" is a oneway method. '
                    'It cannot receive a REPLY.' % function_spec.name
                )

            # Responses: replace the wire-level body with the parsed result.
            message.body = function_spec.result_spec.from_wire(message.body)
        else:
            # Unrecognized message type. If this happens, we have a bug
            # because deserialize_message already raises an exception for
            # invalid message IDs.
            raise ValueError(
                'Unknown message type %d in message %r'
                % (message.message_type, message)
            )

        return message
